from prepare_hetero_datasets import prepare_hetero_datasets
from load_dataset import load_dataset
from predict import predict
from get_optimizer import get_optimizer
from get_model import create_model
import numpy as np
import torch
import sys
import pickle
from itertools import accumulate
import time
import argparse
import os
# Ask OpenMP-based libraries to use a single thread (also inherited by any
# child processes started after this point).
# NOTE(review): numpy and torch are already imported above; runtimes that read
# OMP_NUM_THREADS only at load time will ignore this — consider moving it
# above the imports if single-threading is required.
os.environ["OMP_NUM_THREADS"] = "1"
# NOTE(review): the local modules imported above were resolved before this
# append, so it only affects imports performed later.
sys.path.append('./')


def train(save_path, model_path, **kargs):
    """Train a model and periodically pickle train/test metrics to ``save_path``.

    Loads ``kargs["dataset_name"]``, builds the training/prediction loaders
    with ``prepare_hetero_datasets`` (``n_workers`` workers, homogeneity
    parameter ``q``), creates the model and the optimizer, then calls
    ``optimizer.update()`` ``kargs["n_global_iters"]`` times.  Every
    ``kargs["save_intvl"]`` iterations the model returned by
    ``optimizer.get_model()`` is evaluated on the training and test sets, the
    results are appended to a metrics dict, and that dict is pickled to
    ``save_path``.  The file is written once more after the loop.

    Training stops early when the evaluated training loss exceeds twice the
    loss measured before the first update, or is NaN.

    Args:
        save_path: File the metrics dict is pickled to.  The dict has
            list-valued keys ``train_loss``, ``train_acc``,
            ``train_grad_norm``, ``test_loss``, ``test_acc``,
            ``test_grad_norm`` (one entry per evaluation) and ``args``,
            which holds ``kargs``.
        model_path: Optional path to a state dict loaded into the freshly
            created model; falsy to start from the model's initialization.
        **kargs: Experiment settings.  Keys read: ``dataset_name``,
            ``n_workers``, ``q``, ``train_batch_size``, ``pred_batch_size``,
            ``model_name``, ``optimizer_name``, ``eta``, ``weight_decay``,
            ``gpu_id``, ``n_local_iters``, ``n_global_iters``, ``save_intvl``.
    """
    train_data, test_data, n_classes, n_channels = load_dataset(
        kargs["dataset_name"])

    train_loaders, pred_loader_on_train_data, pred_loader_on_test_data \
        = prepare_hetero_datasets(train_data, test_data, n_classes,
                                  n_workers=kargs["n_workers"],
                                  q=kargs["q"],
                                  train_batch_size=kargs["train_batch_size"],
                                  pred_batch_size=kargs["pred_batch_size"])

    model = create_model(kargs["model_name"], n_classes, n_channels)
    if model_path:
        model.load_state_dict(torch.load(model_path))
    model.train()

    optimizer = get_optimizer(optimizer_name=kargs["optimizer_name"],
                              model=model,
                              eta=kargs["eta"],
                              weight_decay=kargs["weight_decay"],
                              train_loaders=train_loaders,
                              gpu_id=kargs["gpu_id"],
                              n_local_iters=kargs["n_local_iters"],
                              n_workers=kargs["n_workers"])

    saved_info = {"train_loss": [], "train_acc": [], "train_grad_norm": [],
                  "test_loss": [], "test_acc": [], "test_grad_norm": [],
                  "args": kargs}

    def _save():
        # Rewrite the whole file so it always reflects the latest metrics.
        with open(save_path, "wb") as f:
            pickle.dump(saved_info, f)

    # Baseline for the divergence check inside the loop.
    init_train_loss, _, _ = predict(optimizer.get_model(),
                                    pred_loader_on_train_data,
                                    kargs["weight_decay"])

    start_time = time.time()
    for i in range(kargs["n_global_iters"]):
        optimizer.update()
        if (i + 1) % kargs["save_intvl"] != 0:
            continue

        train_loss, train_acc, train_grad_norm = predict(
            optimizer.get_model(), pred_loader_on_train_data,
            kargs["weight_decay"])
        # The test set is evaluated with weight_decay=0.0 (presumably so the
        # regularization term is excluded from the test loss — confirm).
        test_loss, test_acc, test_grad_norm = predict(
            optimizer.get_model(), pred_loader_on_test_data, 0.0)

        saved_info["train_loss"].append(train_loss)
        saved_info["train_acc"].append(train_acc)
        saved_info["train_grad_norm"].append(train_grad_norm)
        saved_info["test_loss"].append(test_loss)
        saved_info["test_acc"].append(test_acc)
        saved_info["test_grad_norm"].append(test_grad_norm)
        # NOTE(review): presumably predict() switches to eval mode; this
        # restores train mode on `model`, which may not be the same object
        # optimizer.get_model() returns — confirm.
        model.train()
        _save()

        # Written as `not (loss <= bound)` so that a NaN loss, for which every
        # comparison is False, also stops training.
        if not (train_loss <= 2 * init_train_loss):
            print("Learning was stopped")
            break
        print("Iter: {} | Train Loss: {}, Train Acc: {}, Train Grad Norm {}, | Test Loss: {}, Test Acc: {}, Test Grad Norm: {}, Elapsed Time: {}"
              .format(i+1, train_loss, train_acc, train_grad_norm, test_loss, test_acc, test_grad_norm, time.time() - start_time))

    # Final write: also produces the file when no evaluation point was reached
    # (n_global_iters < save_intvl).
    _save()


if __name__ == '__main__':
    # Command-line entry point: parse the experiment settings, derive the
    # results path, seed the RNGs, configure the GPU if present, and train.
    parser = argparse.ArgumentParser()
    # Parameters
    parser.add_argument("--eta", type=float, default=0.01)
    # General experimental info
    parser.add_argument("--n_workers", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dataset_name", type=str, default='cifar10')
    parser.add_argument("--model_name", type=str, default='fc')
    parser.add_argument("--optimizer_name", type=str, default='lsarah')
    parser.add_argument("--exp_name", type=str, default='test')
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--n_global_iters", type=int, default=1000)
    parser.add_argument("--n_local_iters", type=int, default=64)
    parser.add_argument("--train_batch_size", type=int, default=16)
    # homogeneity parameter (refer to the paper)
    parser.add_argument("--q", type=float, default=0.1)
    # Fixed args
    parser.add_argument("--use_pretrain_model", type=int, default=0)
    parser.add_argument("--weight_decay", type=float, default=0.005)
    parser.add_argument("--pred_batch_size", type=int, default=1024)
    parser.add_argument("--save_intvl", type=int, default=10)
    args = parser.parse_args()

    # results/<exp>/<model>/<dataset>/homogeneity=<q>/<opt>_K=<K>_b=<b>/eta=<eta>/
    save_dir = os.path.join("results", args.exp_name, args.model_name, args.dataset_name, "homogeneity="+str(args.q),
                            args.optimizer_name + "_K="+str(args.n_local_iters) + "_b="+str(args.train_batch_size), "eta="+str(args.eta))
    # exist_ok avoids a race when several seeds are launched in parallel.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "seed="+str(args.seed)+".pickle")
    model_path = None
    if args.use_pretrain_model:
        model_path = os.path.join(
            "saved_models", "pretrain", args.model_name, args.dataset_name, "model.pth")

    # Seed torch and numpy from --seed regardless of whether CUDA is present.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        # NOTE(review): is_available() above may already have initialized
        # CUDA, in which case this does not restrict device visibility for
        # this process (set_device(gpu_id) below relies on that); it is still
        # inherited by child processes — confirm intent.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
        torch.cuda.set_device(args.gpu_id)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print("GPU Enabled")
    else:
        print("GPU Not Enabled")
    train(save_path=save_path, model_path=model_path, **vars(args))
